{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3dc29882",
   "metadata": {},
   "source": [
    "# Detecting audio issues in the Common Voice dataset\n",
    "This notebook aims at showing how you can leverage sliceguard to detect issues in audio datasets, using the commonvoice dataset as an example. Focus will be on the basic workflow, as well as showing how to leverage different embedding models from the huggingface hub."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ada14102",
   "metadata": {},
   "source": [
    "In order to run this example you will need some **dependencies**. Install them as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3999c7a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install sliceguard librosa soundfile datasets tqdm jiwer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8385885",
   "metadata": {},
   "source": [
    "## Step 1: Generate predictions for the Common Voice dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ced433ff",
   "metadata": {},
   "source": [
    "**IMPORTANT NOTE**: In order to access the commonvoice dataset you have to accept certain terms and conditions. To do this, create a huggingface account and accept the terms and conditions [HERE](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0). You then need to **create an access token** to access your datasets programmatically. Follow the steps for configuring one [HERE](https://huggingface.co/docs/hub/security-tokens). It is just a matter of few minutes. Just paste your access token into a file called **access_token.txt** and place it in the same directory as this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9478724",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure this example here.\n",
    "# Like this it is optimized for fast execution only using whisper tiny.\n",
    "HF_MODEL = \"openai/whisper-tiny\"\n",
    "ACCESS_TOKEN_FILE = \"access_token.txt\"\n",
    "AUDIO_SAVE_DIR = \"audios\"\n",
    "NUM_SAMPLES = 2000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00dc6079",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Some imports your will need to execute this\n",
    "import uuid\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import librosa\n",
    "import soundfile as sf\n",
    "from jiwer import wer\n",
    "from datasets import load_dataset, Audio\n",
    "from transformers import pipeline\n",
    "from transformers import WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15373734",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the acces token for downloading the dataset\n",
    "access_token = Path(ACCESS_TOKEN_FILE).read_text()\n",
    "cv_13 = load_dataset(\"mozilla-foundation/common_voice_13_0\", \"en\", use_auth_token=access_token, streaming=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "254cd94c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate an ASR pipeline with the configured model\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "feature_extractor = WhisperFeatureExtractor.from_pretrained(HF_MODEL)\n",
    "tokenizer = WhisperTokenizer.from_pretrained(HF_MODEL, language=\"en\", task=\"transcribe\")\n",
    "model = WhisperForConditionalGeneration.from_pretrained(HF_MODEL).to(device)\n",
    "\n",
    "model.config.forced_decoder_ids = tokenizer.get_decoder_prompt_ids() # Specify the task as we always want to use german and transcribe\n",
    "model.config.language = \"<|en|>\"\n",
    "model.config.task = \"transcribe\"\n",
    "\n",
    "pipe = pipeline(\"automatic-speech-recognition\", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7af43e84",
   "metadata": {},
   "outputs": [],
   "source": [
    "keys_to_save = [\"sentence\", \"up_votes\", \"down_votes\", \"age\", \"gender\", \"accent\", \"locale\", \"segment\", \"variant\"]\n",
    "\n",
    "audio_save_dir = Path(AUDIO_SAVE_DIR)\n",
    "if  not audio_save_dir.is_dir():\n",
    "    audio_save_dir.mkdir()\n",
    "else:\n",
    "    shutil.rmtree(audio_save_dir)\n",
    "    audio_save_dir.mkdir()\n",
    "\n",
    "num_samples = 0\n",
    "data = []\n",
    "for sample in tqdm(cv_13[\"train\"], total=NUM_SAMPLES):\n",
    "    new_audio = librosa.resample(sample[\"audio\"][\"array\"], orig_sr=sample[\"audio\"][\"sampling_rate\"], target_sr=16000)\n",
    "    file_stem = str(uuid.uuid4())\n",
    "    cur_data = {}\n",
    "    for k in keys_to_save:\n",
    "        cur_data[k] = sample[k]\n",
    "    prediction = pipe(new_audio)[\"text\"]\n",
    "    cur_data[\"prediction\"] = prediction\n",
    "    \n",
    "    sample_wer = wer(sample[\"sentence\"], prediction)\n",
    "    cur_data[\"wer\"] = sample_wer\n",
    "    \n",
    "    target_path = audio_save_dir / (file_stem + \".wav\")\n",
    "    cur_data[\"audio\"] = target_path\n",
    "    sf.write(target_path, new_audio, 16000)\n",
    "    data.append(cur_data)\n",
    "    num_samples += 1\n",
    "    if num_samples > NUM_SAMPLES:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d9d684",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(data)\n",
    "df[\"audio\"] = df[\"audio\"].astype(\"string\") # otherwise overflow in serializing json\n",
    "df.to_json(\"dataset.json\", orient=\"records\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deab58d8",
   "metadata": {},
   "source": [
    "## Step 2: Detect issues caused by environmental noise\n",
    "First check we want to do is checking wether there are audio recordings that are somehow so different from the rest of the data that they cannot be properly transcribed. Here we mostly target **general audio properties and environmental noise** such as background noises.\n",
    "\n",
    "In order to do this, we leverage **general purpose audio embeddings** of a model trained on Audioset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ddc05d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Some imports you will need for this step\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from jiwer import wer\n",
    "from sliceguard import SliceGuard\n",
    "from renumics.spotlight import Audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a41fbff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the generated dataset including the predictions\n",
    "df = pd.read_json(\"dataset.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13fe8c32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the metric function\n",
    "def wer_metric(y_true, y_pred):\n",
    "    return np.mean([wer(s_y, s_pred) for s_y, s_pred in zip(y_true, y_pred)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b131b55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform an initial detection aiming for relatively small clusters of minimum 5 similar samples\n",
    "sg = SliceGuard()\n",
    "issues = sg.find_issues(\n",
    "        df,\n",
    "        [\"audio\"],\n",
    "        \"sentence\",\n",
    "        \"prediction\",\n",
    "        wer_metric,\n",
    "        metric_mode=\"min\",\n",
    "        embedding_models={\"audio\": \"MIT/ast-finetuned-audioset-10-10-0.4593\"},\n",
    "        min_support=5,\n",
    "        min_drop=0.2,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4462ee73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the issues using Renumics Spotlight\n",
    "sg.report(spotlight_dtype={\"audio\": Audio})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba55979a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Of course if you want to run additional checks you don't need to recompute the embeddings all the time.\n",
    "# Just save them here, and supply the precomputed embeddings in the next call\n",
    "# where we will check for smaller clusters aka outliers.\n",
    "computed_embeddings = sg.embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df13fff6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Perform an additional detection, targeting outliers with significant drops (see min_support and min_drop)\n",
    "# We even allow for clusters containing single samples here.\n",
    "sg = SliceGuard()\n",
    "issues = sg.find_issues(\n",
    "        df,\n",
    "        [\"audio\"],\n",
    "        \"sentence\",\n",
    "        \"prediction\",\n",
    "        wer_metric,\n",
    "        metric_mode=\"min\",\n",
    "        min_support=1,\n",
    "        min_drop=0.3,\n",
    "        precomputed_embeddings=computed_embeddings\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ab785bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the issues using Renumics Spotlight\n",
    "sg.report(spotlight_dtype={\"audio\": Audio})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8b6d287",
   "metadata": {},
   "source": [
    "## Step 3: Detect issues caused by (uncommon) speakers\n",
    "While the previous detection example targeted finding general audio conditions that can cause issues, this is not always the criterion we want to check for. A way of defining other criterions is **changing the underlying embedding** to **capture different properties of the data**. In this case, we define the embedding model to be a model for **speaker identification**. This should allow us, to **detect uncommon speakers**, although they are note explicitely labeled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "408a0fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform a detection using a speaker identification model for computing embeddings.\n",
    "# This will help to recover problematic speakers even though they are not explicitely labeled.\n",
    "sg = SliceGuard()\n",
    "issues = sg.find_issues(\n",
    "        df,\n",
    "        [\"audio\"],\n",
    "        \"sentence\",\n",
    "        \"prediction\",\n",
    "        wer_metric,\n",
    "        metric_mode=\"min\",\n",
    "        embedding_models={\"audio\": \"superb/wav2vec2-base-superb-sid\"},\n",
    "        min_support=1,\n",
    "        min_drop=0.3,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5faee9cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the issues using Renumics Spotlight\n",
    "sg.report(spotlight_dtype={\"audio\": Audio})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
